import torch
from torch_geometric.nn import global_mean_pool
from models.new_mlp import NewMLP as MLP
from torch.nn import Linear
import models.DMPNN_geometric as mpn

def mpnn_block(in_f, out_f, *args, **kwargs):
    """Construct an ``mpn.DMPNN`` whose first argument is ``(in_f, in_f)``.

    Only ``in_f`` and the extra positional ``args`` reach the constructor.
    ``out_f`` and ``kwargs`` are accepted for signature compatibility but are
    not forwarded.
    """
    size_pair = (in_f, in_f)
    layer = mpn.DMPNN(size_pair, *args)
    return layer

def reverse_index(idx_tensor):
    """Return a CPU tensor in which every index is swapped with its pair partner.

    Indices are treated as consecutive pairs ``(2k, 2k + 1)``: an even value
    ``2k`` maps to ``2k + 1`` and an odd value ``2k + 1`` maps to ``2k``.
    The input tensor is left unmodified.
    """
    shifted = idx_tensor + 1
    # After the +1 shift, values that became even were odd originally and
    # must step back to the even member of their pair.
    partner = torch.where(shifted % 2 == 0, shifted - 2, shifted)
    return partner.cpu()

class transform_network(torch.nn.Module):
    """Fully connected ReLU network mapping ``latent_size`` features to ``latent_size``.

    The stack is ``Linear -> ReLU`` for every width in ``net_width``, followed
    by a final ``Linear`` back to ``latent_size`` (no activation after it).
    The whole stack is stored as ``self.MLP``.
    """

    def __init__(self, latent_size=50, net_width=[200, 200, 200, 200]):
        super().__init__()
        self.latent_size = latent_size

        # ``net_width`` is only read here, never mutated, so the shared
        # default list is not altered between instantiations.
        widths = [latent_size, *net_width]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(torch.nn.Linear(fan_in, fan_out))
            layers.append(torch.nn.ReLU())
        # Indexing net_width directly keeps the IndexError on an empty list.
        layers.append(torch.nn.Linear(net_width[-1], latent_size))

        self.MLP = torch.nn.Sequential(*layers)

    def forward(self, data):
        """Apply the layer stack to ``data`` and return the result."""
        return self.MLP(data)
    
class HeadNetwork(torch.nn.Module):
    """Prediction head: an optional hidden ``MLP`` followed by a final ``Linear``.

    Args:
        in_f: Number of input features.
        hidden: Hidden layer sizes handed to ``MLP`` after ``in_f``. ``None``
            selects ``[in_f, in_f // 2, in_f // 4]``. An empty sequence skips
            the hidden stack, so the head is a single ``Linear(in_f, out_f)``.
        out_f: Number of output features.
        dropout: Dropout value forwarded to ``MLP``.
        task_name: Label stored on the module as ``self.task_name``.
    """

    def __init__(self, in_f, hidden=None, out_f=1, dropout=0.2, task_name="target"):
        super().__init__()
        self.task_name = task_name

        self.in_f = in_f
        self.hidden = [in_f, in_f // 2, in_f // 4] if hidden is None else hidden
        self.out_f = out_f
        self.dropout = dropout
        if len(self.hidden) == 0:
            # No hidden layers: features go straight into the output layer.
            self.ffn = torch.nn.Identity()
            self.ffn_fin = Linear(self.in_f, self.out_f, bias=True)
        else:
            self.ffn = MLP([self.in_f, *self.hidden], dropout=self.dropout)
            # Size the output layer from the last hidden width rather than a
            # hard-coded in_f // 4, which only matched the default ``hidden``.
            # NOTE(review): assumes MLP emits ``hidden[-1]`` features - confirm
            # against models.new_mlp.NewMLP.
            self.ffn_fin = Linear(self.hidden[-1], self.out_f, bias=True)

    def forward(self, x):
        """Run ``x`` through the hidden stack and the final linear layer."""
        return self.ffn_fin(self.ffn(x))


class BottleneckNetwork(torch.nn.Module):
    """Encoder/decoder pair of ``MLP`` modules with mirrored layer sizes.

    The encoder is built from ``[in_f, *hidden]`` and the decoder from the
    same list reversed. ``task_name`` is stored on the module unchanged.
    """

    def __init__(self, in_f, hidden, dropout, task_name="target"):
        super().__init__()
        self.task_name = task_name

        widths = [in_f] + list(hidden)

        # Encoder is created before the decoder so the parameter
        # initialisation order stays the same.
        self.mlp_e = MLP(widths, dropout=dropout)
        self.mlp_d = MLP(widths[::-1], dropout=dropout)

    def forward(self, input):
        """Return ``(encoded, decoded)`` where ``decoded = mlp_d(mlp_e(input))``."""
        encoded = self.mlp_e(input)
        decoded = self.mlp_d(encoded)
        return encoded, decoded

class SingleNetwork_front(torch.nn.Module):
    """Front stage: extracts tensors from a batch object and runs ``mpn.DMPNN_front``.

    ``heads``, ``dropout`` and ``task_name`` are accepted by the constructor
    but are not used by this class.
    """

    def __init__(self, in_f, out_f, depth=2, heads=1, dropout=0.2, task_name="target"):
        super().__init__()
        self.in_f = in_f
        self.out_f = out_f
        self.depth = depth

        self.mpn_sub = mpn.DMPNN_front(in_f, self.out_f * 2, depth=self.depth)

    def get_datas(self, data, idxs=None):
        """Pull the ten tensors used by ``forward`` out of ``data``.

        With ``idxs=None`` the attributes of ``data`` are returned untouched.
        Otherwise only nodes whose ``data.batch`` id occurs in ``idxs[0]`` are
        kept: batch ids are renumbered consecutively from 0, edges are kept
        when their row-0 endpoint is a kept node, and edge endpoints are
        relabelled to positions within the kept node list. ``atom_map_idx`` is
        always returned unfiltered.

        Returns:
            Tuple ``(x, batch, edge_index, edge_index_cycle, edge_attr,
            edge_attr_cycle, atom_map_idx, b_keep_idx, edge_attr_main, x_main)``.
        """
        if idxs is None:
            return (
                data.x,
                data.batch,
                data.edge_index,
                data.edge_index_cycle,
                data.edge_attr,
                data.edge_attr_cycle,
                data.atom_map_idx,
                data.b_keep_idx,
                data.edge_attr_main,
                data.x_main,
            )

        node_mask = torch.isin(data.batch, idxs[0])
        node_idxs = torch.where(node_mask)[0]

        # Per-graph node counts of the surviving graphs; dropping zero counts
        # and re-expanding yields batch ids 0..k-1 without gaps.
        counts = torch.bincount(data.batch[node_mask])
        counts = counts[counts.bool()]
        batch = torch.arange(len(counts)).to(counts).repeat_interleave(counts, dim=0)

        def select_edges(full_edge_index):
            # Keep edges by their row-0 endpoint, then map every endpoint to
            # its position inside node_idxs.
            keep = torch.isin(full_edge_index, node_idxs)[0]
            kept = full_edge_index[:, keep]
            local = (kept.reshape(-1, 1) == node_idxs).int().argmax(dim=1)
            return keep, local.reshape(2, -1)

        edge_mask, edge_index = select_edges(data.edge_index)
        cycle_mask, edge_index_cycle = select_edges(data.edge_index_cycle)

        return (
            data.x[node_mask],
            batch,
            edge_index,
            edge_index_cycle,
            data.edge_attr[edge_mask],
            data.edge_attr_cycle[cycle_mask],
            data.atom_map_idx,
            data.b_keep_idx[edge_mask],
            data.edge_attr_main[edge_mask],
            data.x_main[node_mask],
        )

    def forward(self, data, idxs=None):
        """Run ``mpn_sub`` on ``data`` (optionally restricted by ``idxs``) and return its output.

        ``atom_map_idx`` is always passed to ``mpn_sub`` as ``None``; the cycle
        edge attributes, ``edge_attr_main`` and ``x_main`` are extracted but
        not forwarded.
        """
        (
            x,
            batch,
            edge_index,
            edge_index_cycle,
            edge_attr,
            _edge_attr_cycle,
            _atom_map_idx,
            b_keep_idx,
            _edge_attr_main,
            _x_main,
        ) = self.get_datas(data, idxs)

        edge_positions = torch.arange(edge_index.shape[1])
        rev_edge_index = reverse_index(edge_positions)
        rev_edge_index_cycle = reverse_index(b_keep_idx)

        return self.mpn_sub(
            x,
            batch,
            edge_index,
            edge_index_cycle,
            edge_attr,
            rev_edge_index,
            atom_map_idx=None,
            b_keep_idx=b_keep_idx,
            rev_edge_index_cycle=rev_edge_index_cycle,
        )

class SingleNetwork_back(torch.nn.Module):
    """Back stage: ``mpn.DMPNN_back`` followed by mean pooling and an ``MLP``.

    ``forward`` returns the pooled embedding only; ``node_out`` additionally
    returns the per-node output of ``mpn_sub``.
    """

    def __init__(self, in_f, out_f, depth=2, heads=1, dropout=0.2, task_name="target"):
        super().__init__()
        self.in_f = in_f
        self.out_f = out_f
        self.depth = depth
        self.heads = heads
        self.dropout = dropout

        self.mpn_sub = mpn.DMPNN_back(in_f, self.out_f * 2, depth=self.depth)
        self.mlp_init = MLP([self.out_f * 2, self.out_f * 2, self.out_f], dropout=self.dropout)

    def _pooled_and_nodes(self, inputs, perturbation):
        """Shared pipeline: return ``(pooled embedding, per-node output of mpn_sub)``."""
        # The pooling index is read from position 1 before any perturbation.
        batch = inputs[1]
        if perturbation is not None:
            # Only the first element of ``inputs`` is perturbed.
            first, *rest = inputs
            inputs = (perturbation(first), *rest)

        node_feats, _edge_feats = self.mpn_sub(*inputs)
        pooled = global_mean_pool(node_feats, batch)
        return self.mlp_init(pooled), node_feats

    def forward(self, inputs, perturbation=None):
        """Return ``mlp_init`` applied to the mean-pooled node output of ``mpn_sub``.

        ``inputs`` is unpacked positionally into ``mpn_sub``; ``inputs[1]`` is
        used as the pooling index. When ``perturbation`` is given it is applied
        to ``inputs[0]`` first.
        """
        pooled, _node_feats = self._pooled_and_nodes(inputs, perturbation)
        return pooled

    def node_out(self, inputs, perturbation=None):
        """Same computation as ``forward`` but also return the per-node output.

        Returns:
            Tuple ``(pooled embedding, node output of mpn_sub)``.
        """
        pooled, node_feats = self._pooled_and_nodes(inputs, perturbation)
        return pooled, node_feats


class SingleNetwork(torch.nn.Module):
    """Chain ``SingleNetwork_front`` (depth ``depth - 1``) into ``SingleNetwork_back`` (depth 1).

    When ``node`` is true, ``forward`` also returns the per-node output of the
    back stage.
    """

    def __init__(self, in_f, out_f, depth=4, heads=1, dropout=0.2, task_name="target", node=False):
        super().__init__()
        self.front = SingleNetwork_front(in_f, out_f, depth - 1, heads, dropout, task_name)
        self.node = node
        # The back stage is constructed the same way whatever ``node`` is;
        # the flag only selects which output ``forward`` returns.
        self.back = SingleNetwork_back(in_f, out_f, 1, heads, dropout, task_name)

    def forward(self, data, idxs=None):
        """Run the front stage, then the back stage.

        Returns:
            The pooled embedding, or ``(pooled embedding, node output)`` when
            ``self.node`` is true.
        """
        front_out = self.front(data, idxs)
        if not self.node:
            return self.back(front_out)

        pooled, per_node = self.back.node_out(front_out)
        return pooled, per_node